In [1]:
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import glob
from PIL import Image
import os
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as dset
import torchvision.utils as vutils
from torchvision.utils import make_grid
import pandas as pd
from IPython.display import HTML
from tqdm.auto import tqdm
from torchvision.models import inception_v3
from torch.cuda.amp import GradScaler, autocast

def is_cuda():
    """Report CUDA availability and return the matching device string.

    Returns:
        "cuda" if a CUDA device is available, otherwise "cpu". Also prints
        a short status message either way.
    """
    if not torch.cuda.is_available():
        print("No CUDA. Working on CPU.")
        return "cpu"
    print("CUDA available")
    return "cuda"

device = is_cuda()
CUDA available
In [2]:
# Hyperparameters and data location.
root = "../input/tomjerrysc/"  # Kaggle dataset of Tom & Jerry frames
batch_size = 8
image_size = 256  # training images are resized/cropped to image_size x image_size
nc = 3 # n channels
nz = 512 # n latent dim
# NOTE(review): nz=512 here, but Generator below defaults to z_dim=256 and
# make_noise() samples 256-dim latents — these constants are not actually used.
ngf = 64 # size of generator feature map
ndf = 64 # size of discriminator feature map
# NOTE(review): lr and beta1 are defined but never passed to the optimizers
# below, which therefore run with Adam's defaults.
lr = 0.0001
beta1 = 0.05
ngpu = 1
In [3]:
def show_tensor_images(image_tensor, num_images=8, size=(3, 64, 64), nrow=4, figsize=8):
    """Plot a grid of images taken from a batch tensor.

    Expects pixel values in [-1, 1] (Tanh / Normalize(0.5, 0.5) range) and
    rescales them to [0, 1] before display. `size` is unused and kept only
    for interface compatibility.
    """
    rescaled = (image_tensor + 1) / 2
    batch_cpu = rescaled.detach().cpu()
    grid = make_grid(batch_cpu[:num_images], nrow=nrow)
    plt.figure(figsize=(figsize, figsize))
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()
    
def to_rgb(img):
    """Return a 3-channel RGB copy of a PIL image by pasting it onto a fresh RGB canvas."""
    canvas = Image.new("RGB", img.size)
    canvas.paste(img)
    return canvas

class ImageSet(Dataset):
    """Flat-folder image dataset: every file matching `root/*.*` is one sample.

    Each item is opened with PIL, forced to RGB via to_rgb(), and passed
    through `transform`.
    """

    def __init__(self, root, transform):
        self.root = root
        self.transform = transform
        # Sorted so the sample ordering is deterministic across runs.
        self.imgs = sorted(glob.glob(os.path.join(root, "*.*")))

    def __getitem__(self, index):
        # Wrap the index so any non-negative index maps onto a valid file.
        path = self.imgs[index % len(self.imgs)]
        img = to_rgb(Image.open(path))
        return self.transform(img)

    def __len__(self):
        return len(self.imgs)

# Preprocessing: resize the shorter side to image_size, center-crop to a
# square, convert to a tensor, and normalize each channel to [-1, 1]
# (matching the generator's Tanh output range).
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


dataset = ImageSet(root=root, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
In [4]:
# Sanity check: visualize one batch of real training images.
real_batch = next(iter(dataloader))
show_tensor_images(real_batch)
In [5]:
def weights_init(m):
    """DCGAN-style weight init, intended for use with Module.apply().

    Conv layers get N(0, 0.02) weights; batch-norm layers get N(1, 0.02)
    weights and zero bias. Other modules are left untouched.
    """
    name = m.__class__.__name__
    if "Conv" in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
In [6]:
# Pretrained Inception v3, used later as the feature extractor for FID.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision
# releases in favor of the `weights=` argument — confirm the installed version.
inception = inception_v3(pretrained=True) # For computation of FID score

class SLE(nn.Module):
    """Skip-layer channel gating.

    A low-resolution feature map with `in_channel` channels is pooled to 4x4
    and squeezed to `in_channel // 8` sigmoid gates (one per channel), which
    then scale the high-resolution map. `high` is therefore expected to carry
    in_channel // 8 channels; spatial sizes broadcast over the 1x1 gate.
    """

    def __init__(self, in_channel):
        super().__init__()
        self.block = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=4),
            nn.Conv2d(in_channel, in_channel, 4, 1, 0),      # 4x4 -> 1x1
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channel, in_channel // 8, 1, 1, 0),  # squeeze channels
            nn.Sigmoid(),
        )

    def forward(self, high, low):
        gate = self.block(low)
        return high * gate

def make_noise(n_samples=batch_size, z_dim=256, device="cuda"):
    """Sample a batch of latent vectors for the generator.

    Args:
        n_samples: number of latent vectors (defaults to the global batch_size).
        z_dim: latent dimensionality. Bug fix: this parameter was previously
            ignored — the dimension was hard-coded to 256. Default behavior
            (z_dim=256) is unchanged for existing callers.
        device: torch device string for the returned tensor.

    Returns:
        Tensor of shape (n_samples, z_dim, 1, 1), shaped for ConvTranspose2d input.
    """
    noise = torch.randn(n_samples, z_dim, device=device)
    return noise[:, :, None, None]

class Generator(nn.Module):
    """Generator mapping a (z_dim x 1 x 1) latent to a 256x256 RGB image in [-1, 1].

    Each stage doubles resolution via nearest-neighbour upsampling + 3x3 conv.
    Two SLE modules gate late high-resolution feature maps with early
    low-resolution ones (8x8 gates the 128x128 map, 16x16 gates the 256x256 map).
    """

    def __init__(self, z_dim=256, out_res=256):
        super().__init__()
        assert out_res == 256, "Only Output Resolution of 256x256 Implemented, got {}".format(out_res)
        # Stem: 1x1 latent -> 4x4 (transposed conv) -> 8x8 with 2*z_dim channels.
        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(z_dim, z_dim, 4, 1, 0),
            nn.BatchNorm2d(z_dim),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(z_dim, 2 * z_dim, 3, 1, 1),
            nn.BatchNorm2d(2 * z_dim),
            nn.ReLU()
        )
        # Upsampling tower: each block doubles the spatial resolution.
        self.block2 = self.make_block(2 * z_dim, z_dim)
        self.block3 = self.make_block(z_dim, z_dim // 2)
        self.block4 = self.make_block(z_dim // 2, z_dim // 4)
        self.block5 = self.make_block(z_dim // 4, z_dim // 4)
        self.block6 = self.make_block(z_dim // 4, z_dim // 8)
        self.out = nn.Sequential(
            nn.Conv2d(z_dim // 8, 3, 3, 1, 1),
            nn.Tanh()
        )

        # NOTE(review): these channel counts are hard-coded for z_dim == 256,
        # consistent with the out_res assertion above.
        self.SLE1 = SLE(512)
        self.SLE2 = SLE(256)

    def make_block(self, in_channel, out_channel):
        """One upsampling stage: 2x nearest upsample, 3x3 conv, BN, ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        feat8 = self.block1(x)       # 512 x 8 x 8
        feat16 = self.block2(feat8)  # 256 x 16 x 16
        y = self.block3(feat16)      # 128 x 32 x 32
        y = self.block4(y)           # 64 x 64 x 64
        y = self.block5(y)           # 64 x 128 x 128
        y = self.SLE1(y, feat8)      # gate 128x128 map with the 8x8 map
        y = self.block6(y)           # 32 x 256 x 256
        y = self.SLE2(y, feat16)     # gate 256x256 map with the 16x16 map
        return self.out(y)           # 3 x 256 x 256

class Decoder(nn.Module):
    """Upscales a feature map 16x (four 2x stages) and projects it to 3 channels.

    Instantiated by the Discriminator to reconstruct images from intermediate
    feature maps. Every stage's conv reads `in_feature` input channels, so the
    three intermediate stages keep the channel count constant and the final
    stage maps it to RGB.
    """

    def __init__(self, in_feature=32):
        super().__init__()
        self.in_feature = in_feature
        stages = [self.make_block(in_feature) for _ in range(3)]
        stages.append(self.make_block(3))
        self.decoder = nn.Sequential(*stages)

    def make_block(self, out_feature):
        """One stage: 2x nearest upsample, conv(in_feature -> out_feature), BN, ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(self.in_feature, out_feature, 3, 1, 1),
            nn.BatchNorm2d(out_feature),
            nn.ReLU()
        )

    def forward(self, x):
        return self.decoder(x)

class Discriminator(nn.Module):
    """256x256 discriminator with residual-style skip branches and two
    auxiliary decoders that reconstruct images from intermediate features.

    forward() returns a single score map when reconstruction is disabled, or
    a tuple (scores, center_reconstruction, global_reconstruction) when it is
    enabled (see make_recon()).
    """

    def __init__(self, hidden_dim=64, in_res=256):
        super().__init__()
        assert in_res == 256, "Only Output Resolution of 256x256 Implemented, got {}".format(in_res)
        # Stem: 256 -> 64 spatial, hidden_dim channels.
        self.block1 = nn.Sequential(
            nn.Conv2d(3, hidden_dim//2, 4, 2, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(hidden_dim//2, hidden_dim, 4, 2, 1),
            nn.BatchNorm2d(hidden_dim),
            nn.LeakyReLU(0.1)
        )
        # Each stage halves resolution; `skip*` are avg-pool + 1x1-conv shortcuts.
        self.block2 = self.make_block(hidden_dim, hidden_dim)
        self.skip2 = self.down_sample(hidden_dim, hidden_dim)
        self.block3 = self.make_block(hidden_dim, hidden_dim//2)
        self.skip3 = self.down_sample(hidden_dim, hidden_dim//2)
        self.block4 = self.make_block(hidden_dim//2, hidden_dim//2)
        self.skip4 = self.down_sample(hidden_dim//2, hidden_dim//2)
        self.out = nn.Sequential(
            nn.Conv2d(hidden_dim//2, hidden_dim//4, 1, 1, 0),
            nn.BatchNorm2d(hidden_dim//4),
            nn.LeakyReLU(0.1),
            nn.Conv2d(hidden_dim//4, 1, 4, 1, 0)
        )
        # Auxiliary decoders: reconstruct an image from the cropped 8x8 feature
        # map (decoder1) and from the full final feature map (decoder2).
        self.decoder1 = Decoder()
        self.decoder2 = Decoder()
        # Bug fix: `recon` was previously set only by make_recon(), so calling
        # forward() before it raised AttributeError. Default to True, matching
        # the first use in the training loop (reconstruction on real images).
        self.recon = True

    def make_recon(self, recon=True):
        """Toggle whether forward() also returns decoder reconstructions."""
        self.recon = recon

    def forward(self, x):
        y = self.block1(x)
        # Residual-style stages: main branch + downsampled shortcut.
        y1 = self.block2(y)
        y2 = self.skip2(y)
        y = y1 + y2
        y1 = self.block3(y)
        y2 = self.skip3(y)
        h1 = y1 + y2        # 32 x 16 x 16 : For cropping
        y1 = self.block4(h1)
        y2 = self.skip4(h1)
        # Simply center crop for now, where the literature implemented random crop
        if len(h1.shape) == 4:
            h1 = h1[:, :, 4:12, 4:12]
        elif len(h1.shape) == 3:
            h1 = h1[:, 4:12, 4:12]
        else:
            print("invalid shape for feature map to be cropped, {}".format(h1.shape))
        h2 = y1 + y2        # 32 x 8 x 8
        y = self.out(h2)    # 1 x 5 x 5
        if self.recon:
            y_part = self.decoder1(h1)   # reconstruction from the center crop
            y_recon = self.decoder2(h2)  # reconstruction from the whole feature map
            # y: 5 x 5 real/fake score map
            return y, y_part, y_recon
        return y

    def make_block(self, in_channel, out_channel):
        """Main downsampling branch: strided 4x4 conv then 3x3 conv, each BN + LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 4, 2, 1),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.1),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.1)
        )

    def down_sample(self, in_channel, out_channel):
        """Shortcut branch: 2x average pool, 1x1 conv, BN, LeakyReLU."""
        return nn.Sequential(
            nn.AvgPool2d(2, 2),
            nn.Conv2d(in_channel, out_channel, 1, 1, 0),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.1)
        )
Downloading: "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-1a9a5a14.pth
In [7]:
def hinge_loss(output, real=True):
    """Hinge adversarial loss for the discriminator.

    real=True:  mean(max(0, 1 - output)) — push real scores above +1.
    real=False: mean(max(0, 1 + output)) — push fake scores below -1.
    """
    margin = 1 - output if real else 1 + output
    return torch.clamp(margin, min=0).mean()

def recon_loss(output, target):
    """Reconstruction loss: Frobenius norm of (output - target).

    NOTE(review): torch.norm over the whole tensor is already a scalar, so
    the outer mean is a no-op kept for parity with the original formulation.
    """
    diff = output - target
    return torch.mean(torch.norm(diff))

def gen_loss(output):
    """Generator loss: negative mean of the discriminator's scores on fakes."""
    return -output.mean()

# Instantiate the networks, apply DCGAN-style init, and move them to the device.
G = Generator()
G.apply(weights_init)
G.to(device)
D = Discriminator()
D.apply(weights_init)
D.to(device)
# NOTE(review): the configured lr/beta1 from the config cell are not passed
# here, so both optimizers use Adam's defaults.
G_optim = optim.Adam(G.parameters())
D_optim = optim.Adam(D.parameters())
# Separate gradient scalers for the two mixed-precision backward passes.
scaler1 = GradScaler()
scaler2 = GradScaler()

D_l, G_l = [], []           # per-step loss histories
imgs_list = []              # periodic snapshots of G(fixed_noise)
fixed_noise = make_noise()  # fixed latent batch for progress visualization
cur_iter = 0
num_iters = 15000

# Alternating discriminator / generator updates with mixed precision
# (autocast + GradScaler). The epoch loop repeats until num_iters steps.
while cur_iter < num_iters:
    for real in tqdm(dataloader):

        real = real.to(device)
        D_optim.zero_grad()
        # Reconstruction outputs are requested for the real pass only.
        D.make_recon(True)
        
        with autocast():
        
            # Real pass: adversarial scores plus two reconstructions.
            D_real_pred, I_part, I_glob = D(real)
            D_real_loss = hinge_loss(D_real_pred)

            # Fake pass: adversarial scores only.
            noise = make_noise()
            fake = G(noise)
            D.make_recon(False)
            D_fake_pred = D(fake)
            D_fake_loss = hinge_loss(D_fake_pred, real=False)

            # Reconstruction targets: the image's 128x128 center crop (mirrors
            # the 8x8 center crop of the 16x16 feature map inside D) and a
            # 2x-downscaled copy of the whole image.
            if len(real.shape)==4:
                real_part = real[:,:,64:192, 64:192]
            elif len(real.shape)==3:
                real_part = real[:,64:192, 64:192]
            else:
                print("Invalid real shape, {}".format(real.shape))
            real_glob = F.interpolate(real, scale_factor=0.5)
            D_recon_loss = recon_loss(I_part, real_part) + recon_loss(I_glob, real_glob)
            D_loss = D_real_loss + D_fake_loss + D_recon_loss
            
        D_l.append(D_loss.item())
        scaler1.scale(D_loss).backward()
        scaler1.step(D_optim)
        scaler1.update()

        G_optim.zero_grad()
        
        with autocast():
            
            # Generator step: fresh noise; D is still in recon=False mode, so
            # D(fake) returns a single score tensor here.
            noise = make_noise()
            fake = G(noise)
            D_fake_pred = D(fake)
            G_loss = gen_loss(D_fake_pred)

        G_l.append(G_loss.item())
        scaler2.scale(G_loss).backward()
        scaler2.step(G_optim)
        scaler2.update()
        
        cur_iter += 1
        

        if (cur_iter) % 1000 == 0:
            print("{} / {}, D_loss: {:.4f}, G_loss: {:.4f}".format(cur_iter, num_iters, D_loss.item(), G_loss.item()))
            # NOTE(review): these sample generations run with gradients
            # enabled; wrapping them in torch.no_grad() would save memory.
            noise = make_noise()
            fake = G(noise)
            show_tensor_images(fake)
            imgs_list.append(G(fixed_noise).detach().cpu())
            
            # Checkpoint both networks every 1000 iterations.
            torch.save(G.state_dict(), "G.pt")
            torch.save(D.state_dict(), "D.pt")
            
        # Release the largest intermediates before the next iteration.
        del D_fake_pred, D_real_pred, I_part, I_glob, fake
        torch.cuda.empty_cache()

        
1000 / 15000, D_loss: 635.2388, G_loss: 0.0540
2000 / 15000, D_loss: 549.0581, G_loss: 0.7266
3000 / 15000, D_loss: 559.1874, G_loss: 2.1055
4000 / 15000, D_loss: 579.7663, G_loss: 2.6523
5000 / 15000, D_loss: 525.9560, G_loss: 0.7217
6000 / 15000, D_loss: 500.2514, G_loss: 2.6660
7000 / 15000, D_loss: 669.2898, G_loss: -0.0410
8000 / 15000, D_loss: 445.1941, G_loss: 2.0098
9000 / 15000, D_loss: 556.1753, G_loss: 1.3311
10000 / 15000, D_loss: 537.1878, G_loss: 0.4487
11000 / 15000, D_loss: 419.3483, G_loss: 1.2754
12000 / 15000, D_loss: 524.2166, G_loss: 3.9805
13000 / 15000, D_loss: 434.4963, G_loss: 3.9141
14000 / 15000, D_loss: 524.7975, G_loss: 3.0312
15000 / 15000, D_loss: 463.1813, G_loss: 0.4011
In [8]:
from torchvision.models import inception_v3
from torch.distributions import MultivariateNormal
import scipy
from scipy import linalg

# Replace the classifier head with identity so the network outputs pooled
# features (for FID) instead of class logits.
inception.fc = nn.Identity()
inception.to(device)

# resnet = models.resnet50(pretrained=True)
def matrix_sqrt(x):
    """Matrix square root of a square matrix via scipy.linalg.sqrtm.

    sqrtm may return a complex array due to numerical noise, so only the real
    part is kept. Bug fix: the legacy `torch.Tensor(data, device=...)`
    constructor rejects CUDA devices and silently downcasts to float32; use
    torch.tensor and restore the input's dtype and device instead.
    """
    y = linalg.sqrtm(x.cpu().detach().numpy())
    return torch.tensor(y.real, dtype=x.dtype, device=x.device)

def frechet_distance(mu_x, mu_y, sig_x, sig_y):
    """Frechet distance between two Gaussians (mu_x, sig_x) and (mu_y, sig_y):

        ||mu_x - mu_y||^2 + Tr(sig_x + sig_y - 2 (sig_x sig_y)^(1/2))
    """
    mean_term = torch.norm(mu_x - mu_y).pow(2)
    cov_term = torch.trace(sig_x + sig_y - 2 * matrix_sqrt(torch.matmul(sig_x, sig_y)))
    return mean_term + cov_term

def preprocess(img):
    """Bilinearly resize a batch of images to Inception v3's 299x299 input size."""
    return F.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)

def get_cov(x):
    """Covariance matrix of a 2-D feature tensor (rows = samples, cols = features)."""
    features = x.detach().numpy()
    return torch.Tensor(np.cov(features, rowvar=False))
In [9]:
# fake_lst, real_lst = [], []
# G.eval()
# n_samples=10000
# batch_size=4
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# cur_samples=0

# with torch.no_grad():

#     for real_example in tqdm(dataloader, total=n_samples // batch_size): # Go by batch
#         real_samples = preprocess(real_example)
#         real_features = inception(real_samples.to(device)) # Move features to CPU
#         real_lst.append(torch.Tensor(real_features[0].cpu()))

#         fake_samples = make_noise()
#         fake_samples = preprocess(G(fake_samples))
#         fake_features = inception(fake_samples.to(device))
#         fake_lst.append(torch.Tensor(fake_features[0].cpu()))
#         cur_samples += len(real_samples)
#         if cur_samples >= n_samples:
#             break
In [10]:
# fake_features_all = torch.cat(fake_lst)
# real_features_all = torch.cat(real_lst)

# mu_fake = fake_features_all.mean(0)
# mu_real = real_features_all.mean(0)
# sigma_fake = get_cov(fake_features_all)
# sigma_real = get_cov(real_features_all)

# with torch.no_grad():
#     print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())
In [11]:
def slerp(val, low, high):
    """Spherical linear interpolation between batches of vectors.

    `val` in [0, 1] is the interpolation fraction; vectors are interpolated
    along the great-circle arc between their directions. Norms and the inner
    product are taken over dim 1, so inputs are expected batched with the
    vector dimension at axis 1.
    """
    unit_low = low / torch.norm(low, dim=1, keepdim=True)
    unit_high = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((unit_low * unit_high).sum(1))
    sin_omega = torch.sin(omega)
    coef_low = (torch.sin((1.0 - val) * omega) / sin_omega).unsqueeze(1)
    coef_high = (torch.sin(val * omega) / sin_omega).unsqueeze(1)
    return coef_low * low + coef_high * high
In [12]:
# Spherical interpolation between two batches of 4 latents, at 9 steps
# (val = 0.1 .. 0.9) -> 36 latents total.
z1 = make_noise(n_samples=4)
z2 = make_noise(n_samples=4)

zs = torch.cat([slerp(v, z1, z2) for v in np.arange(0.1, 1, 0.1)])
# Reorder from step-major to pair-major so each group of 9 latents is one
# pair's full trajectory (one row of the grid plotted below).
zs = torch.cat([zs[4*k,:,:,:].unsqueeze(0) for k in range(9)]+[zs[4*k+1,:,:,:].unsqueeze(0) for k in range(9)]+
               [zs[4*k+2,:,:,:].unsqueeze(0) for k in range(9)]+[zs[4*k+3,:,:,:].unsqueeze(0) for k in range(9)])
In [13]:
# 4 rows x 9 columns: each row shows one latent pair's interpolation trajectory.
show_tensor_images(G(zs), num_images=36, nrow=9, figsize=16)
In [14]:
# fig = plt.figure(figsize=(8,8))
# plt.axis("off")
# ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in imgs_list]
# ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# HTML(ani.to_jshtml())
In [ ]: